{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Neural Networks\n",
    "===============\n",
    "\n",
    "This tutorial is based on [Deep Learning with PyTorch: A 60 Minute Blitz](https://pytorch.org/tutorials/beginner/blitz/neural_networks_tutorial.html)\n",
    "\n",
    "In this tutorial, we build a classifier for images of handwritten digits (the MNIST dataset):\n",
    "\n",
    "![classifier](fig/classifier.png)\n",
    "\n",
    "We use the famous **LeNet5** network:\n",
    "\n",
    "![LeNet](fig/LeNet.png)\n",
    "\n",
    "It is a simple feed-forward network. It takes the input, feeds it\n",
    "through several layers one after the other, and then finally gives the\n",
    "output.\n",
    "\n",
    "For an intuition of how a convolutional layer works, see the animations below.\n",
    "\n",
    "<table><tr>\n",
    "    <td><img src=\"https://www.cc.gatech.edu/~san37/img/dl/conv.gif\" border=0></td>\n",
    "    <td><img src=\"https://upload.wikimedia.org/wikipedia/commons/4/4f/3D_Convolution_Animation.gif\" border=0></td>\n",
    "</tr></table>\n",
    "\n",
    "A typical training procedure for a neural network is as follows:\n",
    "\n",
    "![](fig/torch_flow.png)\n",
    "- Define the neural network that has some learnable **parameters** (or\n",
    "  weights)\n",
    "- Iterate over a dataset of inputs\n",
    "- Process input through the network\n",
    "- Compute the loss (how far the output is from being correct)\n",
    "- Propagate gradients back into the network’s parameters\n",
    "- Update the weights of the network, typically using a simple update rule:\n",
    "  $\\theta \\leftarrow \\theta - \\alpha \\nabla_\\theta \\mathcal{L}(\\theta)$\n",
    "  * the learning rate $\\alpha$ (`learning_rate`) is a **hyperparameter**\n",
    "\n",
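    "The update rule above can be sketched in plain Python. This is only an illustrative sketch, not how PyTorch implements it: the one-dimensional loss $\\mathcal{L}(\\theta) = (\\theta - 3)^2$, the starting point and the number of steps are assumptions chosen for the example.\n",
    "\n",
    "```python\n",
    "# Assumed toy loss: L(theta) = (theta - 3)^2, minimized at theta = 3\n",
    "def grad(theta):\n",
    "    return 2.0 * (theta - 3.0)  # dL/dtheta\n",
    "\n",
    "alpha = 0.1   # learning rate (hyperparameter)\n",
    "theta = 0.0   # assumed initial parameter value\n",
    "\n",
    "for step in range(50):\n",
    "    theta = theta - alpha * grad(theta)  # theta <- theta - alpha * gradient\n",
    "\n",
    "print(round(theta, 4))\n",
    "```\n",
    "\n",
    "Each step shrinks the distance to the minimum by a factor of $1 - 2\\alpha = 0.8$, so `theta` approaches 3.\n",
    "\n",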
    "Define the network\n",
    "------------------\n",
    "\n",
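    "Let’s define this network. The cell below also defines a smaller, single-convolution ``Net``, which is the model instantiated as ``net``:\n",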
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "class LeNet(nn.Module): # inherits from nn.Module!\n",
    "\n",
    "    def __init__(self):\n",
    "        super(LeNet, self).__init__()\n",
    "        self.features = nn.Sequential( # input: 1*28*28\n",
    "            # in_channel, out_channel, kernel_size\n",
    "            nn.Conv2d(1, 6, 5), # 6*24*24\n",
    "            nn.ReLU(inplace=True),\n",
    "            nn.MaxPool2d(2, 2), # 6*12*12\n",
    "            nn.Conv2d(6, 16, 5), # 16*8*8\n",
    "            nn.ReLU(inplace=True),\n",
    "            nn.MaxPool2d(2, 2) # 16*4*4\n",
    "        )\n",
    "        self.classifier = nn.Sequential(\n",
    "            nn.Linear(16 * 4 * 4, 120),\n",
    "            nn.ReLU(inplace=True),\n",
    "            nn.Linear(120, 84),\n",
    "            nn.ReLU(inplace=True),\n",
    "            nn.Linear(84, 10)\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.features(x)\n",
    "        x = x.view(-1, 16 * 4 * 4)\n",
    "        x = self.classifier(x)\n",
    "        return x\n",
    "\n",
    "class Net(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(Net, self).__init__()\n",
    "        # in_chan, out_chan, kernel_size\n",
    "        self.conv1 = nn.Conv2d(1, 8, 3, stride=1)\n",
    "        self.relu = nn.ReLU()\n",
    "        self.fc = nn.Linear(26 * 26 * 8, 10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        output = self.conv1(x)\n",
    "        output = self.relu(output)\n",
    "        output = output.view(output.shape[0],-1)\n",
    "        output = self.fc(output)\n",
    "        return output\n",
    "\n",
    "net = Net()\n",
    "print(net)"
   ]
  },
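  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick check of the layer sizes, the spatial size after a convolution or pooling layer with no dilation is $\\lfloor (n + 2p - k) / s \\rfloor + 1$ for input size $n$, kernel size $k$, stride $s$ and padding $p$. The helper `out_size` below is a hypothetical name introduced only for this sketch; it traces a $28 \\times 28$ MNIST image through the layers of `LeNet` and `Net` as defined above.\n",
    "\n",
    "```python\n",
    "def out_size(n, kernel, stride=1, padding=0):\n",
    "    # Output size of a conv/pool layer along one spatial dimension (no dilation)\n",
    "    return (n + 2 * padding - kernel) // stride + 1\n",
    "\n",
    "# LeNet.features: conv 5x5 -> pool 2x2 -> conv 5x5 -> pool 2x2\n",
    "n = 28\n",
    "n = out_size(n, 5)            # Conv2d(1, 6, 5)\n",
    "n = out_size(n, 2, stride=2)  # MaxPool2d(2, 2)\n",
    "n = out_size(n, 5)            # Conv2d(6, 16, 5)\n",
    "n = out_size(n, 2, stride=2)  # MaxPool2d(2, 2)\n",
    "lenet_flat = 16 * n * n       # input features of the first Linear layer\n",
    "\n",
    "# Net: a single 3x3 convolution with 8 output channels\n",
    "net_flat = 8 * out_size(28, 3) ** 2\n",
    "\n",
    "print(lenet_flat, net_flat)\n",
    "```\n",
    "\n",
    "The two results match the `in_features` of `nn.Linear(16 * 4 * 4, 120)` and `nn.Linear(26 * 26 * 8, 10)`."
   ]
  },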
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You just have to define the ``forward`` function, and the ``backward``\n",
    "function (where gradients are computed) is automatically defined for you\n",
    "using ``autograd``.\n",
    "You can use any of the Tensor operations in the ``forward`` function.\n",
    "\n",
    "The learnable parameters of a model are returned by ``net.parameters()``."
   ]
  },
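  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For illustration, the parameters of a module can be listed together with their shapes. The tiny `nn.Linear(4, 2)` layer below is an assumption made for this sketch so that the count is easy to verify by hand; the same loop works on `net`.\n",
    "\n",
    "```python\n",
    "import torch.nn as nn\n",
    "\n",
    "layer = nn.Linear(4, 2)  # assumed toy module: weight (2, 4) and bias (2,)\n",
    "\n",
    "# Each parameter is a Tensor that requires gradients by default\n",
    "for name, param in layer.named_parameters():\n",
    "    print(name, tuple(param.shape), param.requires_grad)\n",
    "\n",
    "# Total number of learnable scalars: 2 * 4 + 2 = 10\n",
    "num_params = sum(p.numel() for p in layer.parameters())\n",
    "print(num_params)\n",
    "```"
   ]
  },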
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Loss Function\n",
    "-------------\n",
    "A loss function takes the (output, target) pair of inputs, and computes a\n",
    "value that estimates how far away the output is from the target.\n",
    "\n",
    "There are several different\n",
    "[loss functions](https://pytorch.org/docs/nn.html#loss-functions) under the\n",
    "``nn`` package.\n",
    "A simple loss is ``nn.MSELoss``, which computes the mean-squared error\n",
    "between the input and the target.\n",
    "For classification problems, however, we commonly use [``nn.CrossEntropyLoss``](https://pytorch.org/docs/stable/nn.html#crossentropyloss).\n",
    "\n",
    "(**Notice**: In PyTorch's implementation, **CrossEntropyLoss = LogSoftmax + NLLLoss**, so the last layer of the network should output raw scores (logits) without a softmax.)"
   ]
  },
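  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make this relation concrete, the cross-entropy of a single sample can be computed by hand as the negative log-softmax of the raw outputs at the target class. The logits and the target below are made-up values, and the helper `log_softmax` is written out for this sketch rather than taken from PyTorch.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "def log_softmax(logits):\n",
    "    # Subtract the max before exponentiating for numerical stability\n",
    "    m = max(logits)\n",
    "    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))\n",
    "    return [z - log_sum_exp for z in logits]\n",
    "\n",
    "logits = [2.0, 1.0, 0.1]  # assumed raw network outputs for 3 classes\n",
    "target = 0                # assumed index of the correct class\n",
    "\n",
    "# Negative log-likelihood of the correct class\n",
    "loss = -log_softmax(logits)[target]\n",
    "print(loss)\n",
    "```\n",
    "\n",
    "The loss is small when the logit of the correct class dominates and grows as probability mass moves to the other classes."
   ]
  },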
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "criterion = nn.CrossEntropyLoss()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Later we can define loss as\n",
    "\n",
    "``loss = criterion(output, target)``\n",
    "\n",
    "then call `loss.backward()` to do the differentiation automatically.\n",
    "\n",
    "If you follow ``loss`` in the backward direction, using its\n",
    "``.grad_fn`` attribute, you will see a graph of computations. For the\n",
    "``LeNet`` model defined above, it looks like this:\n",
    "\n",
    "```python\n",
    "    input -> conv2d -> relu -> maxpool2d -> conv2d -> relu -> maxpool2d\n",
    "          -> view -> linear -> relu -> linear -> relu -> linear\n",
    "          -> CrossEntropyLoss\n",
    "          -> loss\n",
    "```\n",
    "\n",
    "So, when we call ``loss.backward()``, the loss is differentiated\n",
    "through the whole graph, and all Tensors in the graph that have ``requires_grad=True``\n",
    "will have their ``.grad`` Tensor accumulated with the gradient."
   ]
  },
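  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because gradients are accumulated rather than overwritten, calling `backward()` twice without clearing `.grad` sums the two gradients. A minimal sketch, assuming a made-up two-element tensor `w` and the toy loss $\\sum_i w_i^2$:\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "w = torch.tensor([1.0, 2.0], requires_grad=True)  # assumed toy parameter\n",
    "\n",
    "loss = (w ** 2).sum()\n",
    "loss.backward()           # d(loss)/dw = 2 * w\n",
    "first = w.grad.clone()\n",
    "\n",
    "loss = (w ** 2).sum()     # rebuild the graph and backpropagate again\n",
    "loss.backward()           # the new gradient is added to the stored one\n",
    "second = w.grad.clone()\n",
    "\n",
    "w.grad.zero_()            # reset the accumulated gradient in place\n",
    "print(first, second, w.grad)\n",
    "```\n",
    "\n",
    "This is why the training loop calls `optimizer.zero_grad()` before each backward pass."
   ]
  },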
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Optimizer\n",
    "--------\n",
    "\n",
    "In order to update the weights/parameters following the rule below, we need to define an optimizer.\n",
    "\n",
    "``weight = weight - learning_rate * gradient``\n",
    "\n",
    "Here we use Stochastic Gradient Descent (SGD):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "LEARNING_RATE = 0.1\n",
    "optimizer = torch.optim.SGD(net.parameters(), lr=LEARNING_RATE)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Other useful optimizers, including Nesterov-SGD, Adam and RMSProp, can be found in ``torch.optim``."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load dataset\n",
    "\n",
    "Now we can load the [MNIST](http://yann.lecun.com/exdb/mnist/) dataset and prepare for training.\n",
    "![MNIST](https://upload.wikimedia.org/wikipedia/commons/thumb/2/27/MnistExamples.png/220px-MnistExamples.png)\n",
    "\n",
    "There is no need to download it manually: ``torchvision.datasets`` wraps common datasets and, with ``download=True``, fetches the files on first use."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchvision import datasets, transforms\n",
    "\n",
    "BATCH_SIZE = 64 # Mini-batch size\n",
    "\n",
    "transform = transforms.Compose([\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize((0.1307, ), (0.3081, ))\n",
    "])\n",
    "\n",
    "train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)\n",
    "test_dataset = datasets.MNIST('./data', train=False, download=True, transform=transform)\n",
    "\n",
    "train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)\n",
    "test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Let's train!\n",
    "\n",
    "First, define the hyperparameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "NUM_EPOCHS = 20\n",
    "# LEARNING_RATE and BATCH_SIZE were already defined above"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, set the training device. To train on a GPU, use `DEVICE = torch.device('cuda')`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "DEVICE = torch.device('cpu') # cuda\n",
    "net = net.to(DEVICE)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now put together the pieces defined so far."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "net.train() # set the network in train mode\n",
    "for epoch_idx in range(NUM_EPOCHS):\n",
    "    for batch_idx, (data, target) in enumerate(train_dataloader): # train in mini-batches\n",
    "        # get (x_i, y_i)\n",
    "        # be careful of their shapes\n",
    "        # data: (N, channels, height, width), e.g. (64, 1, 28, 28) for BATCH_SIZE = 64\n",
    "        # target: (N, )\n",
    "        data, target = data.to(DEVICE), target.to(DEVICE)\n",
    "\n",
    "        optimizer.zero_grad() # clear gradients left over from the previous step\n",
    "        output = net(data) # forward\n",
    "\n",
    "        loss = criterion(output, target) # calculate loss\n",
    "\n",
    "        loss.backward() # backward\n",
    "        optimizer.step() # update parameters\n",
    "        break # demo only: stop after the first mini-batch of each epoch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "But wait, how can we know the performance of the network?\n",
    "\n",
    "We need to add an evaluation (inference) step."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate(model_eval, loader_eval, criterion_eval):\n",
    "    model_eval.eval() # set the network in evaluation mode\n",
    "    loss_eval = 0.\n",
    "    correct = 0\n",
    "    with torch.no_grad(): # no gradients are needed for evaluation\n",
    "        for data, target in loader_eval:\n",
    "            data, target = data.to(DEVICE), target.to(DEVICE)\n",
    "            output = model_eval(data) # use the model passed in, not the global net\n",
    "            # the criterion returns a batch mean (default reduction='mean'),\n",
    "            # so weight it by the batch size before summing\n",
    "            loss_eval += criterion_eval(output, target).item() * data.size(0)\n",
    "\n",
    "            pred = output.argmax(dim=1, keepdim=True) # index of the highest score\n",
    "            correct += pred.eq(target.view_as(pred)).sum().item()\n",
    "\n",
    "    num_samples = len(loader_eval.dataset)\n",
    "    loss_eval = loss_eval / num_samples # mean loss per sample\n",
    "    accuracy = correct / num_samples\n",
    "    response = {'loss': loss_eval, 'acc': accuracy}\n",
    "    return response"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then redefine the training process."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for epoch_idx in range(NUM_EPOCHS):\n",
    "    net.train() # evaluate() leaves the network in eval mode, so switch back each epoch\n",
    "    for batch_idx, (data, target) in enumerate(train_dataloader):\n",
    "        data, target = data.to(DEVICE), target.to(DEVICE)\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        output = net(data)\n",
    "\n",
    "        loss = criterion(output, target)\n",
    "\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "    # add evaluation below\n",
    "    train_resp = evaluate(net, train_dataloader, criterion)\n",
    "    eval_resp = evaluate(net, test_dataloader, criterion)\n",
    "\n",
    "    print('-*-*-*-*-*- Epoch {} -*-*-*-*-*-'.format(epoch_idx))\n",
    "    print('Train Loss: {:.6f}\\t'.format(train_resp['loss']))\n",
    "    print('Train Acc: {:.6f}\\t'.format(train_resp['acc']))\n",
    "    print('Eval Loss: {:.6f}\\t'.format(eval_resp['loss']))\n",
    "    print('Eval Acc: {:.6f}\\t'.format(eval_resp['acc']))\n",
    "    print('\\n')\n",
    "    torch.save(net, 'count.pth') # save model each epoch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To follow the training progress more easily, we can use the `tqdm` package (install it with `pip install tqdm`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "\n",
    "def evaluate(model_eval, loader_eval, criterion_eval):\n",
    "    model_eval.eval() # set the network in evaluation mode\n",
    "    loss_eval = 0.\n",
    "    correct = 0\n",
    "    pbar = tqdm(total=len(loader_eval), desc='Evaluation', ncols=100)\n",
    "    with torch.no_grad(): # no gradients are needed for evaluation\n",
    "        for data, target in loader_eval:\n",
    "            data, target = data.to(DEVICE), target.to(DEVICE)\n",
    "            output = model_eval(data) # use the model passed in, not the global net\n",
    "            # the criterion returns a batch mean (default reduction='mean'),\n",
    "            # so weight it by the batch size before summing\n",
    "            loss_eval += criterion_eval(output, target).item() * data.size(0)\n",
    "\n",
    "            pred = output.argmax(dim=1, keepdim=True) # index of the highest score\n",
    "            correct += pred.eq(target.view_as(pred)).sum().item()\n",
    "            pbar.update(1)\n",
    "    pbar.close()\n",
    "\n",
    "    num_samples = len(loader_eval.dataset)\n",
    "    loss_eval = loss_eval / num_samples # mean loss per sample\n",
    "    accuracy = correct / num_samples\n",
    "    response = {'loss': loss_eval, 'acc': accuracy}\n",
    "    return response\n",
    "\n",
    "for epoch_idx in range(NUM_EPOCHS):\n",
    "    net.train() # evaluate() leaves the network in eval mode, so switch back each epoch\n",
    "    pbar = tqdm(total = len(train_dataloader), desc='Train - Epoch {}'.format(epoch_idx), ncols=100)\n",
    "    for batch_idx, (data, target) in enumerate(train_dataloader):\n",
    "        data, target = data.to(DEVICE), target.to(DEVICE)\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        output = net(data)\n",
    "\n",
    "        loss = criterion(output, target)\n",
    "\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        pbar.update(1)\n",
    "    pbar.close()\n",
    "\n",
    "    # add evaluation below\n",
    "    train_resp = evaluate(net, train_dataloader, criterion)\n",
    "    eval_resp = evaluate(net, test_dataloader, criterion)\n",
    "\n",
    "    print('-*-*-*-*-*- Epoch {} -*-*-*-*-*-'.format(epoch_idx))\n",
    "    print('Train Loss: {:.6f}\\t'.format(train_resp['loss']))\n",
    "    print('Train Acc: {:.6f}\\t'.format(train_resp['acc']))\n",
    "    print('Eval Loss: {:.6f}\\t'.format(eval_resp['loss']))\n",
    "    print('Eval Acc: {:.6f}\\t'.format(eval_resp['acc']))\n",
    "    print('\\n')\n",
    "    torch.save(net, 'count.pth') # save model each epoch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
